package cn.myframe.utils.batch;


import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.sql.*;
import java.util.List;

/**
 * Utility for bulk-importing large data sets through JDBC batches.
 *
 * <p>Call {@link #init(DataSource)} once before using {@link #insertBatch} or {@link #insertBatchIn}.
 * Each call borrows one connection and disables auto-commit. It commits after every {@code sqlCount}
 * rows and once more at the end. If a statement fails, the currently uncommitted chunk is rolled back
 * before the connection is released. Chunks committed earlier stay committed.
 */
public final class JDBCBatch {
    private static final Logger logger = LoggerFactory.getLogger(JDBCBatch.class);

    private JDBCBatch() {
    }

    private static volatile DataSource dataSource;

    /**
     * Registers the data source used by all batch operations.
     *
     * <p>Only the first non-null data source is kept; later calls are ignored. The method is
     * synchronized so two concurrent first calls cannot both win the null check.
     *
     * @param dataSource0 the data source to borrow connections from
     */
    public static synchronized void init(DataSource dataSource0) {
        if (dataSource == null) {
            dataSource = dataSource0;
        }
    }

    /**
     * Inserts {@code list} in JDBC batches. The strategy is chosen from the number of SQL statements:
     * <ul>
     *   <li>1 statement: {@link BatchNoGroup}, a plain insert;</li>
     *   <li>2 statements: {@link BatchGroup}, an insert followed by a group-table insert;</li>
     *   <li>3 or more: {@link BatchGroup2}, two insert variants plus an optional group insert.</li>
     * </ul>
     *
     * @param list     one {@code Object[]} of placeholder values per row; must not be null
     * @param sqlCount rows per committed chunk; values {@code <= 0} commit only once at the end
     * @param sql      the SQL statement(s); must contain at least one element
     * @return the number of rows the chosen strategy counted as inserted
     * @throws SQLException             if a statement fails (the uncommitted chunk is rolled back first)
     * @throws IllegalArgumentException if {@code list} is null or {@code sql} is null or empty
     */
    public static int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
        requireArguments(list, sql);
        AbstractBatch batch;
        if (sql.length == 1) {
            batch = new BatchNoGroup();
        } else if (sql.length == 2) {
            batch = new BatchGroup();
        } else {
            batch = new BatchGroup2();
        }
        return runBatch(batch, list, sqlCount, sql);
    }

    /**
     * Inserts {@code list} with the {@link BatchNoGroupIn} strategy.
     *
     * @param list     rows whose element 0 is an id and whose remaining elements are '+'-joined values
     * @param sqlCount rows per committed chunk; values {@code <= 0} commit only once at the end
     * @param sql      {@code sql[0]} is the statement template; must contain at least one element
     * @throws SQLException             if a statement fails (the uncommitted chunk is rolled back first)
     * @throws IllegalArgumentException if {@code list} is null or {@code sql} is null or empty
     */
    public static void insertBatchIn(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
        requireArguments(list, sql);
        runBatch(new BatchNoGroupIn(), list, sqlCount, sql);
    }

    /** Validates the arguments shared by both public entry points. */
    private static void requireArguments(List<Object[]> list, String[] sql) {
        if (list == null) {
            throw new IllegalArgumentException("list must not be null");
        }
        if (sql == null || sql.length == 0) {
            throw new IllegalArgumentException("at least one SQL statement is required");
        }
    }

    /**
     * Runs one strategy. On failure it rolls back the open chunk, and it always releases the JDBC resources.
     *
     * <p>The rollback matters: {@link AbstractBatch#close()} re-enables auto-commit, and JDBC commits
     * any open transaction when the auto-commit mode changes. Without the rollback, a failed chunk
     * would be committed half-done.
     */
    private static int runBatch(AbstractBatch batch, List<Object[]> list, int sqlCount, String[] sql)
            throws SQLException {
        try {
            return batch.insertBatch(list, sqlCount, sql);
        } catch (SQLException e) {
            batch.rollbackQuietly();
            throw e;
        } catch (RuntimeException e) {
            batch.rollbackQuietly();
            throw e;
        } finally {
            batch.closeQuietly();
        }
    }

    /**
     * Returns true when the row at {@code index} (0-based) completes a chunk of {@code sqlCount} rows.
     * For {@code sqlCount <= 0} it never returns true, so everything is flushed once at the end.
     */
    private static boolean shouldFlush(int index, int sqlCount) {
        return sqlCount > 0 && (index + 1) % sqlCount == 0;
    }

    /**
     * Binds one placeholder value.
     *
     * <p>Integer, String and Long use their typed setters. {@code null} is bound as SQL NULL of type
     * INTEGER. Any other type falls back to {@code setObject}, rather than leaving the parameter unbound.
     */
    private static void bindValue(PreparedStatement ps, int index, Object value) throws SQLException {
        if (value instanceof Integer) {
            ps.setInt(index, (Integer) value);
        } else if (value instanceof String) {
            ps.setString(index, (String) value);
        } else if (value instanceof Long) {
            ps.setLong(index, (Long) value);
        } else if (value == null) {
            ps.setNull(index, Types.INTEGER);
        } else {
            ps.setObject(index, value);
        }
    }

    /**
     * Splits a "||"-joined value and binds the parts, in order, to placeholders 1..n of {@code ps}.
     * The value must not be null.
     */
    private static void bindGroup(PreparedStatement ps, Object packed) throws SQLException {
        String[] parts = packed.toString().split("\\|\\|");
        for (int m = 0; m < parts.length; m++) {
            ps.setString(m + 1, parts[m]);
        }
    }

    /**
     * Executes the pending batch of {@code ps} and returns how many generated keys the driver reported.
     * The key result set is always closed.
     */
    private static int executeCountingKeys(PreparedStatement ps) throws SQLException {
        ps.executeBatch();
        ResultSet keys = ps.getGeneratedKeys();
        if (keys == null) {
            return 0;
        }
        try {
            int n = 0;
            while (keys.next()) {
                n++;
            }
            return n;
        } finally {
            keys.close();
        }
    }

    /**
     * Counts batch entries that succeeded.
     *
     * <p>An entry counts if it reported a positive update count, or {@code SUCCESS_NO_INFO}, which
     * some drivers return for rewritten batches.
     */
    private static int countSucceeded(int[] updateCounts) {
        int n = 0;
        for (int c : updateCounts) {
            if (c > 0 || c == Statement.SUCCESS_NO_INFO) {
                n++;
            }
        }
        return n;
    }

    /**
     * Returns true for null or the empty string. Any other value, including non-String values,
     * counts as non-empty.
     */
    private static boolean isEmptyValue(Object value) {
        return value == null || (value instanceof String && ((String) value).length() == 0);
    }

    /** Logs chunk progress after {@code processed} input rows. */
    private static void logProgress(int processed) {
        logger.info("insert rows :{}.....", processed);
    }

    /** Logs the final row count and the elapsed whole seconds since {@code startMillis}. */
    private static void logElapsed(int count, long startMillis) {
        logger.info("============================= 插入{}条数据用了{}秒 =============================",
                count, (System.currentTimeMillis() - startMillis) / 1000);
    }

    static abstract class AbstractBatch {
        /** Connection borrowed for one batch run; auto-commit is disabled while batching. */
        Connection connection = null;

        /**
         * Main statement, e.g. {@code insert into table values(?,?,?)}.
         */
        PreparedStatement pStatement = null;

        /**
         * Secondary statement, e.g. {@code insert into group_table select id, groupId from table where field = ?}.
         */
        PreparedStatement pStatement2 = null;

        /**
         * Tertiary statement, e.g. {@code insert into table select id, (select id from table where field = ?)}
         * or {@code select id from table where field in ()}.
         */
        PreparedStatement pStatement3 = null;

        /** Plain statement slot; closed on release if set (no implementation in this file sets it). */
        Statement statement = null;

        /**
         * Borrows a connection from the data source registered via {@link JDBCBatch#init(DataSource)}.
         *
         * @return a new connection
         * @throws SQLException if {@code init} has not been called or the data source fails
         */
        Connection getConnection() throws SQLException {
            DataSource ds = JDBCBatch.dataSource;
            if (ds == null) {
                throw new SQLException("JDBCBatch has not been initialised; call JDBCBatch.init(DataSource) first");
            }
            return ds.getConnection();
        }

        /**
         * Closes every statement, then re-enables auto-commit and closes the connection.
         *
         * <p>Every resource is attempted even if an earlier close fails. The first failure is rethrown.
         *
         * @throws SQLException the first failure encountered while releasing resources
         */
        void close() throws SQLException {
            SQLException failure = null;
            failure = closeStatement(statement, failure);
            failure = closeStatement(pStatement, failure);
            failure = closeStatement(pStatement2, failure);
            failure = closeStatement(pStatement3, failure);
            if (connection != null) {
                try {
                    // Restore the auto-commit mode that was disabled for batching.
                    connection.setAutoCommit(Boolean.TRUE);
                } catch (SQLException e) {
                    failure = failure == null ? e : failure;
                }
                try {
                    connection.close();
                } catch (SQLException e) {
                    failure = failure == null ? e : failure;
                }
            }
            if (failure != null) {
                throw failure;
            }
        }

        /** Closes {@code s} if non-null; returns {@code failure}, or the close error if there was none yet. */
        private static SQLException closeStatement(Statement s, SQLException failure) {
            if (s == null) {
                return failure;
            }
            try {
                s.close();
                return failure;
            } catch (SQLException e) {
                return failure == null ? e : failure;
            }
        }

        /** Rolls back the open (uncommitted) chunk, logging rather than throwing on failure. */
        void rollbackQuietly() {
            if (connection == null) {
                return;
            }
            try {
                if (!connection.getAutoCommit()) {
                    connection.rollback();
                }
            } catch (SQLException e) {
                logger.warn("Rollback of the current batch chunk failed", e);
            }
        }

        /** Calls {@link #close()}, logging rather than throwing on failure. */
        void closeQuietly() {
            try {
                close();
            } catch (SQLException e) {
                logger.warn("Failed to release JDBC batch resources", e);
            }
        }

        /**
         * Batch insert.
         *
         * @param list     placeholder values, one {@code Object[]} per row; the trailing element has a
         *                 strategy-specific meaning (e.g. an id/group lookup)
         * @param sqlCount rows per committed chunk ({@code <= 0}: single commit at the end)
         * @param sql      the strategy's SQL statements
         * @return the number of rows the strategy counted as inserted
         * @throws SQLException if a statement fails
         */
        abstract int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException;
    }

    static class BatchNoGroup extends AbstractBatch {
        /**
         * Binds each row to {@code sql[0]} and executes the rows in committed chunks.
         *
         * <p>Let {@code n} be the length of the first row. Only the first {@code n - 1} elements of each
         * row are bound to placeholders 1..n-1. The trailing element is not used by this strategy.
         *
         * @param list     rows of placeholder values
         * @param sqlCount rows per committed chunk
         * @param sql      {@code sql[0]} is the insert statement
         * @return number of generated keys reported by the driver
         * @throws SQLException if a statement fails
         */
        @Override
        int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
            connection = getConnection();
            connection.setAutoCommit(Boolean.FALSE);
            pStatement = connection.prepareStatement(sql[0], Statement.RETURN_GENERATED_KEYS);
            long start = System.currentTimeMillis();
            int boundColumns = list.isEmpty() ? 0 : list.get(0).length - 1;
            int count = 0;
            int pending = 0;
            for (int i = 0; i < list.size(); i++) {
                Object[] row = list.get(i);
                for (int k = 0; k < boundColumns; k++) {
                    bindValue(pStatement, k + 1, row[k]);
                }
                pStatement.addBatch();
                pending++;
                if (shouldFlush(i, sqlCount)) {
                    count += executeCountingKeys(pStatement);
                    connection.commit();
                    pending = 0;
                    logProgress(i + 1);
                }
            }
            // Skip an empty trailing batch so the previous chunk's keys are not counted twice.
            if (pending > 0) {
                count += executeCountingKeys(pStatement);
                connection.commit();
            }
            logElapsed(count, start);
            return count;
        }
    }

    static class BatchGroup extends AbstractBatch {
        /**
         * Inserts each row with {@code sql[0]}, then runs the group-table insert {@code sql[1]} for it.
         *
         * <p>The last element of a row is a "||"-joined string whose parts are bound to the placeholders
         * of {@code sql[1]}. The other elements are bound to {@code sql[0]}. Within each chunk, the main
         * inserts are executed and committed before the group inserts.
         *
         * @param list     rows of placeholder values, last element = "||"-joined group values
         * @param sqlCount rows per committed chunk
         * @param sql      {@code sql[0]} main insert, {@code sql[1]} group insert
         * @return generated keys reported for both statements combined
         * @throws SQLException if a statement fails
         */
        @Override
        int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
            connection = getConnection();
            connection.setAutoCommit(Boolean.FALSE);
            pStatement = connection.prepareStatement(sql[0], Statement.RETURN_GENERATED_KEYS);
            pStatement2 = connection.prepareStatement(sql[1], Statement.RETURN_GENERATED_KEYS); // group sql
            long start = System.currentTimeMillis();
            int width = list.isEmpty() ? 0 : list.get(0).length;
            int count = 0;
            int pending = 0;
            for (int i = 0; i < list.size(); i++) {
                Object[] row = list.get(i);
                for (int k = 0; k < width - 1; k++) {
                    bindValue(pStatement, k + 1, row[k]);
                }
                if (width > 0) {
                    bindGroup(pStatement2, row[width - 1]);
                }
                pStatement.addBatch();
                pStatement2.addBatch();
                pending++;
                if (shouldFlush(i, sqlCount)) {
                    count += flush();
                    pending = 0;
                    logProgress(i + 1);
                }
            }
            if (pending > 0) {
                count += flush();
            }
            logElapsed(count, start);
            return count;
        }

        /** Executes and commits the main batch, then the group batch; returns the combined key count. */
        private int flush() throws SQLException {
            int inserted = executeCountingKeys(pStatement);
            connection.commit();
            inserted += executeCountingKeys(pStatement2);
            connection.commit();
            return inserted;
        }
    }

    /**
     * "IN"-style strategy: each '+'-separated value of a row becomes its own execution of the template.
     */
    static class BatchNoGroupIn extends AbstractBatch {

        /**
         * Element 0 of each row is bound to the template's first '?'. Every later element is split on
         * '+'. For each resulting value, the value is bound to all remaining '?' placeholders and the
         * statement is added to the batch.
         *
         * <p>Values are bound through a {@link PreparedStatement}. This avoids splicing quoted literals
         * into the SQL text, which allowed SQL injection and broke on quotes, '$' or '\' in values.
         *
         * @param list     rows of {@code [id, "a+b", ...]}
         * @param sqlCount rows per committed chunk
         * @param sql      {@code sql[0]} is the statement template
         * @return number of batch entries reported as successful
         * @throws SQLException if a statement fails
         */
        @Override
        int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
            String template = sql[0];
            int placeholders = countPlaceholders(template);
            connection = getConnection();
            connection.setAutoCommit(Boolean.FALSE);
            pStatement = connection.prepareStatement(template);
            long start = System.currentTimeMillis();
            int count = 0;
            int pending = 0;
            boolean isLoad = false;
            for (int i = 0; i < list.size(); i++) {
                Object[] row = list.get(i);
                for (int j = 1; j < row.length; j++) {
                    String joined = row[j] == null ? null : row[j].toString();
                    if (StringUtils.isNotEmpty(joined)) {
                        isLoad = true;
                        for (String feature : joined.split("\\+")) { // feature : a+b
                            bindRow(pStatement, row[0], feature, placeholders);
                            pStatement.addBatch();
                            pending++;
                        }
                    }
                }
                if (pending > 0 && shouldFlush(i, sqlCount)) {
                    count += countSucceeded(pStatement.executeBatch());
                    connection.commit();
                    pending = 0;
                    logProgress(i + 1);
                }
            }
            if (pending > 0) {
                count += countSucceeded(pStatement.executeBatch());
                connection.commit();
            }
            if (isLoad) {
                logElapsed(count, start);
            }
            return count;
        }

        /**
         * Counts '?' characters in {@code sqlTemplate}, including any inside SQL string literals.
         * The original code substituted every '?' textually, so this preserves its placeholder set.
         */
        private static int countPlaceholders(String sqlTemplate) {
            int n = 0;
            for (int p = 0; p < sqlTemplate.length(); p++) {
                if (sqlTemplate.charAt(p) == '?') {
                    n++;
                }
            }
            return n;
        }

        /** Binds {@code id} to placeholder 1 and {@code feature} to placeholders 2..{@code placeholders}. */
        private static void bindRow(PreparedStatement ps, Object id, String feature, int placeholders)
                throws SQLException {
            if (placeholders == 0) {
                return;
            }
            ps.setString(1, id.toString());
            for (int p = 2; p <= placeholders; p++) {
                ps.setString(p, feature);
            }
        }
    }

    private static class BatchGroup2 extends AbstractBatch {
        /**
         * Routes each row to one of two insert variants, then runs an optional group insert.
         *
         * <p>Let {@code n} be the length of the first row. If element {@code n - 2} of a row is null or
         * "", elements 0..n-2 are bound to {@code sql[1]}; otherwise they are bound to {@code sql[0]}.
         * When {@code sql[1]} is used, a null in the selector column (element {@code n - 2}) is left
         * unbound, as in the original. Nulls elsewhere are bound as SQL NULL, so a previous row's value
         * is never reused.
         *
         * <p>If {@code sql[2]} is non-empty, element {@code n - 1} is a "||"-joined string bound to
         * {@code sql[2]}. Each row is added to exactly one insert variant, so no stale parameter set is
         * re-added.
         *
         * @param list     rows of placeholder values
         * @param sqlCount rows per committed chunk
         * @param sql      {@code sql[0]} full insert, {@code sql[1]} insert for an empty selector,
         *                 {@code sql[2]} optional group insert
         * @return number of batch entries reported as successful
         * @throws SQLException if a statement fails
         */
        @Override
        int insertBatch(List<Object[]> list, int sqlCount, String... sql) throws SQLException {
            String groupSql = sql[2];
            boolean hasGroupSql = StringUtils.isNotEmpty(groupSql);
            connection = getConnection();
            connection.setAutoCommit(Boolean.FALSE);
            pStatement = connection.prepareStatement(sql[0]);
            pStatement2 = connection.prepareStatement(sql[1]); // variant used when the selector is empty
            if (hasGroupSql) {
                pStatement3 = connection.prepareStatement(groupSql); // group sql
            }
            long start = System.currentTimeMillis();
            int width = list.isEmpty() ? 0 : list.get(0).length;
            int count = 0;
            boolean pendingFull = false;
            boolean pendingEmpty = false;
            boolean pendingGroup = false;
            for (int i = 0; i < list.size(); i++) {
                Object[] row = list.get(i);
                if (width > 1) {
                    int selector = width - 2;
                    boolean selectorEmpty = isEmptyValue(row[selector]);
                    PreparedStatement target = selectorEmpty ? pStatement2 : pStatement;
                    for (int k = 0; k <= selector; k++) {
                        if (selectorEmpty && k == selector && row[k] == null) {
                            continue; // the empty-selector SQL variant was never given this value
                        }
                        bindValue(target, k + 1, row[k]);
                    }
                    target.addBatch();
                    if (selectorEmpty) {
                        pendingEmpty = true;
                    } else {
                        pendingFull = true;
                    }
                }
                if (hasGroupSql) {
                    if (width > 0) {
                        bindGroup(pStatement3, row[width - 1]);
                    }
                    pStatement3.addBatch();
                    pendingGroup = true;
                }
                if (shouldFlush(i, sqlCount)) {
                    count += flush(pendingFull, pendingEmpty, pendingGroup);
                    pendingFull = false;
                    pendingEmpty = false;
                    pendingGroup = false;
                    logProgress(i + 1);
                }
            }
            count += flush(pendingFull, pendingEmpty, pendingGroup);
            logElapsed(count, start);
            return count;
        }

        /**
         * Executes the batches that have pending rows, in the order full insert, empty-selector insert,
         * group insert, committing after each insert and once at the end.
         */
        private int flush(boolean runFull, boolean runEmpty, boolean runGroup) throws SQLException {
            int inserted = 0;
            if (runFull) {
                inserted += countSucceeded(pStatement.executeBatch());
                connection.commit();
            }
            if (runEmpty) {
                inserted += countSucceeded(pStatement2.executeBatch());
                connection.commit();
            }
            if (runGroup) {
                inserted += countSucceeded(pStatement3.executeBatch());
            }
            connection.commit();
            return inserted;
        }
    }
}
